Hierarchical models

Theory: what is a hierarchical model?

In general: a model with hyperparameters, i.e. parameters that probabilistically control other parameters.

E.g.

\[\begin{align*} y_i &\sim N(\alpha_{group(i)} + \beta \cdot x_i, \sigma) \\ \alpha_{group(i)} &\sim N(\mu, \tau) \end{align*}\]

In this model \(\tau\) is a hyperparameter: it probabilistically controls the parameters \(\alpha_{group(i)}\). Is \(\mu\) also a hyperparameter?

Hierarchical models are great for describing the situation where you know some measurements have something in common (e.g. they come from the same group), but you don’t know how much.
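To make this concrete, here is a minimal simulation from the example model above, with made-up values for \(\mu\), \(\tau\), \(\beta\) and \(\sigma\):

```python
import numpy as np

rng = np.random.default_rng(seed=1)

N_GROUP, N_OBS = 3, 100
MU, TAU = 0.0, 1.0      # hyperparameters controlling the group intercepts
BETA, SIGMA = 0.5, 0.2  # slope and measurement noise (made-up values)

# First draw one intercept per group from the hyperdistribution...
alpha = rng.normal(loc=MU, scale=TAU, size=N_GROUP)

# ...then draw measurements whose means depend on their group's intercept.
group = rng.integers(0, N_GROUP, size=N_OBS)
x = rng.normal(size=N_OBS)
y = rng.normal(loc=alpha[group] + BETA * x, scale=SIGMA)
```

Measurements in the same group share an intercept, but how similar the intercepts are to each other is itself controlled by \(\tau\).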

Learn more!

Example: always be closing!

Plushycorp employs 10 salespeople who go door to door selling cute plushies. The number of plushies that each salesperson sold every working day for four weeks was recorded. What can Plushycorp find out from this data?

To answer the question in a best-case scenario, we can use a hierarchical model to run a “digital twin” of this experiment with known parameters and data generating process. Specifically, we can assume that the number \(y_{ij}\) of plushies that salesperson \(i\) sells on day \(j\) depends on a combination of factors:

  • The baseline amount \(\mu\) that a totally average salesperson would sell on a normal day
  • The salesperson’s ability \(ability_i\)
  • An effect \(day\ effect_j\) for the day of the week: people are thought to buy fewer and fewer plushies as the week drags on.
  • Some random variation

A good first step for modelling count data is the Poisson distribution, so let’s assume that the sales measurements follow this Poisson distribution:1

1 Note the use of the log link function.

\[\begin{align*} y_{ij} &\sim Poisson(\lambda_{ij}) \\ \ln\lambda_{ij} &= \ln\mu + ability_i + day\ effect_j \end{align*}\]
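A side effect of the log link is that effects which add on the log scale multiply on the natural scale: a salesperson with ability 0.1 sells about \(e^{0.1}\approx 1.105\) times the average, whatever the day. A quick sanity check (the numbers here are made up):

```python
import numpy as np

log_mu, ability, day_effect = np.log(2), 0.1, -0.3

lam_additive = np.exp(log_mu + ability + day_effect)
lam_multiplicative = 2 * np.exp(ability) * np.exp(day_effect)
# The two calculations agree: exp(a + b + c) == exp(a) * exp(b) * exp(c)
```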

We know that the salespeople have different abilities, but just how different are they? Since this isn’t really clear to Plushycorp, it makes sense to introduce a parameter \(\tau_{ability}\) that describes the spread of abilities:

\[\begin{equation*} ability_i \sim N(0, \tau_{ability}) \end{equation*}\]

Now we have a hierarchical model!

We can make a similar argument for the day of the week effects:2

2 Can you think of a better model for day effects given the information above??

\[\begin{equation*} day\ effect_j \sim N(0, \tau_{day}) \end{equation*}\]

Finally we can complete our model by specifying prior distributions for the non-hierarchical parameters:3

3 \(HN\) here refers to the “half-normal” distribution, a decent default prior for hierarchical standard deviations

\[\begin{align*} \mu &\sim LN(0, 1) \\ \tau_{ability} &\sim HN(0, 1) \\ \tau_{day} &\sim HN(0, 1) \end{align*}\]
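NumPy doesn’t ship a half-normal sampler, but taking absolute values of normal draws does the same job, which is handy if you want to eyeball what \(HN(0, 1)\) implies before fitting anything:

```python
import numpy as np

rng = np.random.default_rng(seed=12345)

# |N(0, 1)| is exactly HN(0, 1): all the mass sits on the non-negative half-line
tau_draws = np.abs(rng.normal(loc=0, scale=1, size=10_000))
```

The half-normal mean is \(\sqrt{2/\pi}\approx 0.8\), so this prior says that the group-level standard deviations are probably smallish but could plausibly be as big as 2 or so.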

To test out our model with fake data, we can use Python to generate a fake set of salespeople and days, then generate sales that are consistent with our model, starting with some imports and constants:

from pathlib import Path
import json
import numpy as np
import pandas as pd

N_SALESPERSON = 10
N_WEEK = 4
DAY_NAMES = ["Mon", "Tue", "Wed", "Thu", "Fri"]
BASELINE = 2  # 2 plushies in one day is fine
TAU_ABILITY = 0.35
TAU_DAY = 0.2

SEED = 12345
DATA_DIR = Path("../data")

rng = np.random.default_rng(seed=SEED)

with open(DATA_DIR / "names.json", "r") as f:
    name_directory = json.load(f)

names = [
    f"{first_name} {surname}"
    for first_name, surname in zip(
        *map(
            lambda l: rng.choice(l, size=N_SALESPERSON, replace=False),
            name_directory.values()
        )
    )
]

abilities = rng.normal(loc=0, scale=TAU_ABILITY, size=N_SALESPERSON)

salespeople = pd.DataFrame({"salesperson": names, "ability": abilities})

salespeople
salesperson ability
0 Morten Andersen 0.489643
1 Lene Poulsen 0.462804
2 Rasmus Jensen -0.104894
3 Hanne Madsen 0.316022
4 Mette Rasmussen -0.567554
5 Christian Christensen -0.055366
6 Helle Kristensen 0.157319
7 Charlotte Hansen -0.470260
8 Maria Petersen -0.028591
9 Jette Thomsen 0.603659
day_effects = sorted(
    rng.normal(loc=0, scale=TAU_DAY, size=len(DAY_NAMES)),
    reverse=True,  # sort descending: sales drop as the week drags on
)
days = pd.DataFrame({"day": DAY_NAMES, "day_effect": day_effects})
days
day day_effect
0 Mon 0.523632
1 Tue 0.165727
2 Wed 0.155472
3 Thu -0.191798
4 Fri -0.241878
sales = (
    days
    .merge(salespeople, how="cross")
    .merge(pd.DataFrame({"week":[1, 2, 3, 4]}), how="cross")
    .assign(
        sales=lambda df: rng.poisson(
            np.exp(np.log(BASELINE) + df["ability"] + df["day_effect"])
        )
    )
    [["week", "day", "salesperson", "day_effect", "ability", "sales"]]
    .copy()
)
sales.head()
week day salesperson day_effect ability sales
0 1 Mon Morten Andersen 0.523632 0.489643 10
1 2 Mon Morten Andersen 0.523632 0.489643 3
2 3 Mon Morten Andersen 0.523632 0.489643 4
3 4 Mon Morten Andersen 0.523632 0.489643 4
4 1 Mon Lene Poulsen 0.523632 0.462804 4

Here is a chart of each salesperson’s total sales over the whole period:

total_sales = (
    sales.groupby("salesperson")["sales"].sum().sort_values(ascending=False)
)

total_sales.plot(kind="bar", ylabel="Plushies sold", title="Total sales")

It’s pretty straightforward to represent hierarchical models with Stan, almost like Stan was designed for it!

from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="../src/stan/plushies.stan")
print(model.code())
data {
 int<lower=1> N;
 int<lower=1> N_salesperson;
 int<lower=1> N_day;
 array[N] int<lower=1,upper=N_salesperson> salesperson;
 array[N] int<lower=1,upper=N_day> day;
 array[N] int<lower=0> sales;
 int<lower=0,upper=1> likelihood;
}
parameters {
 real log_mu;
 vector[N_salesperson] ability;
 vector[N_day] day_effect;
 real<lower=0> tau_ability;
 real<lower=0> tau_day;
}
transformed parameters {
 vector[N] log_lambda = log_mu + ability[salesperson] + day_effect[day]; 
}
model {
  log_mu ~ normal(0, 1);
  ability ~ normal(0, tau_ability);
  day_effect ~ normal(0, tau_day);
  tau_ability ~ normal(0, 0.5);
  tau_day ~ normal(0, 0.5);
  if (likelihood){
    sales ~ poisson_log(log_lambda);
  }
}
generated quantities {
 real mu = exp(log_mu);
 vector[N] lambda = exp(log_lambda);
 array[N] int yrep = poisson_rng(lambda);
 vector[N] llik; 
 for (n in 1:N){
   llik[n] = poisson_lpmf(sales[n] | lambda[n]);
 }
}

import arviz as az
from stanio.json import process_dictionary

def one_encode(l):
    """Map the elements of a 1d list-like to integer codes 1, 2, ..., len(l)."""
    return dict(zip(l, range(1, len(l) + 1)))


salesperson_codes = one_encode(salespeople["salesperson"])
day_codes = one_encode(days["day"])
data_prior = process_dictionary({
        "N": len(sales),
        "N_salesperson": len(salespeople),
        "N_day": len(days),
        "salesperson": sales["salesperson"].map(salesperson_codes),
        "day": sales["day"].map(day_codes),
        "sales": sales["sales"],
        "likelihood": 0
    }
)
data_posterior = data_prior | {"likelihood": 1}
mcmc_prior = model.sample(data=data_prior)
mcmc_posterior = model.sample(data=data_posterior)
idata = az.from_cmdstanpy(
    posterior=mcmc_posterior,
    prior=mcmc_prior,
    log_likelihood="llik",
    posterior_predictive="yrep",
    observed_data=data_posterior,
    coords={
        "salesperson": salespeople["salesperson"],
        "day": days["day"],
        "observation": sales.index
    },
    dims={
        "lambda": ["observation"],
        "ability": ["salesperson"],
        "day_effect": ["day"],
        "llik": ["observation"],
        "yrep": ["observation"]
    }
)
idata
09:50:55 - cmdstanpy - INFO - CmdStan start processing
09:50:55 - cmdstanpy - INFO - CmdStan done processing.
09:50:55 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 22, column 2 to column 35)
Consider re-running with show_console=True if the above output is unclear!
09:50:56 - cmdstanpy - WARNING - Some chains may have failed to converge.
    Chain 1 had 5 divergent transitions (0.5%)
    Chain 2 had 16 divergent transitions (1.6%)
    Chain 3 had 6 divergent transitions (0.6%)
    Chain 4 had 27 divergent transitions (2.7%)
    Use the "diagnose()" method on the CmdStanMCMC object to see further information.
09:50:56 - cmdstanpy - INFO - CmdStan start processing
09:50:56 - cmdstanpy - INFO - CmdStan done processing.
09:50:56 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 22, column 2 to column 35)
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 23, column 2 to column 34)
Consider re-running with show_console=True if the above output is unclear!
arviz.InferenceData
    • <xarray.Dataset> Size: 13MB
      Dimensions:           (chain: 4, draw: 1000, salesperson: 10, day: 5,
                             log_lambda_dim_0: 200, observation: 200)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * salesperson       (salesperson) object 80B 'Morten Andersen' ... 'Jette T...
        * day               (day) object 40B 'Mon' 'Tue' 'Wed' 'Thu' 'Fri'
        * log_lambda_dim_0  (log_lambda_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * observation       (observation) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
      Data variables:
          log_mu            (chain, draw) float64 32kB 0.5999 0.5902 ... 0.6438 0.7046
          ability           (chain, draw, salesperson) float64 320kB 0.5359 ... 0.6385
          day_effect        (chain, draw, day) float64 160kB 0.3391 ... -0.2967
          tau_ability       (chain, draw) float64 32kB 0.4202 0.4369 ... 0.282 0.5216
          tau_day           (chain, draw) float64 32kB 0.3009 0.2161 ... 0.2873 0.3089
          log_lambda        (chain, draw, log_lambda_dim_0) float64 6MB 1.475 ... 1...
          mu                (chain, draw) float64 32kB 1.822 1.804 ... 1.904 2.023
          lambda            (chain, draw, observation) float64 6MB 4.37 4.37 ... 2.847
      Attributes:
          created_at:                 2024-04-16T07:50:56.938199
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 6MB
      Dimensions:      (chain: 4, draw: 1000, observation: 200)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * observation  (observation) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199
      Data variables:
          yrep         (chain, draw, observation) float64 6MB 7.0 2.0 6.0 ... 3.0 4.0
      Attributes:
          created_at:                 2024-04-16T07:50:56.944208
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 6MB
      Dimensions:      (chain: 4, draw: 1000, observation: 200)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * observation  (observation) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199
      Data variables:
          llik         (chain, draw, observation) float64 6MB -4.726 -1.738 ... -1.448
      Attributes:
          created_at:                 2024-04-16T07:50:57.279223
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 32kB 2.438 1.211 ... 3.896 4.063
          acceptance_rate  (chain, draw) float64 32kB 0.939 0.9959 ... 0.8373 0.995
          step_size        (chain, draw) float64 32kB 0.1937 0.1937 ... 0.1834 0.1834
          tree_depth       (chain, draw) int64 32kB 4 4 4 4 4 4 4 4 ... 4 4 4 4 4 4 4
          n_steps          (chain, draw) int64 32kB 15 15 15 31 15 ... 15 15 15 15 15
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 4.797 3.144 4.33 ... 5.575 1.984
      Attributes:
          created_at:                 2024-04-16T07:50:56.941898
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 26MB
      Dimensions:           (chain: 4, draw: 1000, salesperson: 10, day: 5,
                             log_lambda_dim_0: 200, observation: 200)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * salesperson       (salesperson) object 80B 'Morten Andersen' ... 'Jette T...
        * day               (day) object 40B 'Mon' 'Tue' 'Wed' 'Thu' 'Fri'
        * log_lambda_dim_0  (log_lambda_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * observation       (observation) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
      Data variables:
          log_mu            (chain, draw) float64 32kB 1.804 0.7837 ... 0.5279 1.386
          ability           (chain, draw, salesperson) float64 320kB -0.8463 ... 0....
          day_effect        (chain, draw, day) float64 160kB -0.09451 ... -0.1867
          tau_ability       (chain, draw) float64 32kB 0.5579 0.5346 ... 0.6569 0.2709
          tau_day           (chain, draw) float64 32kB 0.1031 0.127 ... 0.6188 0.4901
          log_lambda        (chain, draw, log_lambda_dim_0) float64 6MB 0.8629 ... ...
          mu                (chain, draw) float64 32kB 6.072 2.189 ... 1.695 3.998
          lambda            (chain, draw, observation) float64 6MB 2.37 2.37 ... 3.473
          yrep              (chain, draw, observation) float64 6MB 3.0 2.0 ... 5.0 2.0
          llik              (chain, draw, observation) float64 6MB -8.845 ... -1.676
      Attributes:
          created_at:                 2024-04-16T07:50:57.272501
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 32kB 5.297 5.612 ... -3.026 4.468
          acceptance_rate  (chain, draw) float64 32kB 0.9963 0.9418 ... 0.9603 0.9464
          step_size        (chain, draw) float64 32kB 0.1406 0.1406 ... 0.1871 0.1871
          tree_depth       (chain, draw) int64 32kB 5 5 5 5 5 5 5 5 ... 5 4 4 5 4 5 4
          n_steps          (chain, draw) int64 32kB 63 31 63 31 31 ... 15 31 31 31 31
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 4.862 2.485 3.01 ... 9.386 9.497
      Attributes:
          created_at:                 2024-04-16T07:50:57.275655
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 10kB
      Dimensions:              (N_dim_0: 1, N_salesperson_dim_0: 1, N_day_dim_0: 1,
                                salesperson_dim_0: 200, day_dim_0: 200,
                                sales_dim_0: 200, likelihood_dim_0: 1)
      Coordinates:
        * N_dim_0              (N_dim_0) int64 8B 0
        * N_salesperson_dim_0  (N_salesperson_dim_0) int64 8B 0
        * N_day_dim_0          (N_day_dim_0) int64 8B 0
        * salesperson_dim_0    (salesperson_dim_0) int64 2kB 0 1 2 3 ... 197 198 199
        * day_dim_0            (day_dim_0) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
        * sales_dim_0          (sales_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * likelihood_dim_0     (likelihood_dim_0) int64 8B 0
      Data variables:
          N                    (N_dim_0) int64 8B 200
          N_salesperson        (N_salesperson_dim_0) int64 8B 10
          N_day                (N_day_dim_0) int64 8B 5
          salesperson          (salesperson_dim_0) int64 2kB 1 1 1 1 2 ... 10 10 10 10
          day                  (day_dim_0) int64 2kB 1 1 1 1 1 1 1 1 ... 5 5 5 5 5 5 5
          sales                (sales_dim_0) int64 2kB 10 3 4 4 4 5 6 ... 4 1 1 3 3 2
          likelihood           (likelihood_dim_0) int64 8B 1
      Attributes:
          created_at:                 2024-04-16T07:50:57.277682
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

az.summary(idata, var_names="~lambda", filter_vars="regex")
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
log_mu 0.777 0.222 0.336 1.180 0.009 0.006 688.0 744.0 1.00
ability[Morten Andersen] 0.438 0.172 0.130 0.772 0.005 0.003 1406.0 1748.0 1.00
ability[Lene Poulsen] 0.374 0.168 0.084 0.711 0.005 0.003 1402.0 1737.0 1.00
ability[Rasmus Jensen] -0.168 0.189 -0.511 0.193 0.004 0.003 1799.0 2023.0 1.00
ability[Hanne Madsen] 0.102 0.178 -0.228 0.446 0.005 0.003 1536.0 2169.0 1.00
ability[Mette Rasmussen] -0.545 0.207 -0.929 -0.150 0.005 0.003 2041.0 2380.0 1.00
ability[Christian Christensen] -0.125 0.183 -0.471 0.221 0.004 0.003 1688.0 2305.0 1.00
ability[Helle Kristensen] 0.185 0.178 -0.152 0.519 0.005 0.003 1497.0 2106.0 1.00
ability[Charlotte Hansen] -0.282 0.193 -0.649 0.083 0.005 0.003 1752.0 2051.0 1.00
ability[Maria Petersen] -0.210 0.189 -0.561 0.155 0.005 0.003 1713.0 1955.0 1.00
ability[Jette Thomsen] 0.411 0.171 0.112 0.751 0.005 0.003 1439.0 1735.0 1.00
day_effect[Mon] 0.342 0.192 -0.021 0.691 0.007 0.005 902.0 678.0 1.01
day_effect[Tue] 0.124 0.191 -0.256 0.452 0.007 0.006 877.0 581.0 1.01
day_effect[Wed] 0.132 0.193 -0.229 0.493 0.007 0.006 930.0 724.0 1.01
day_effect[Thu] -0.302 0.198 -0.716 0.044 0.007 0.005 972.0 828.0 1.01
day_effect[Fri] -0.171 0.191 -0.535 0.179 0.006 0.005 973.0 614.0 1.01
tau_ability 0.398 0.114 0.207 0.610 0.002 0.002 2296.0 2722.0 1.00
tau_day 0.345 0.152 0.124 0.640 0.004 0.003 1930.0 1762.0 1.00
mu 2.227 0.488 1.346 3.157 0.018 0.013 688.0 744.0 1.00

The problem with hierarchical models: funnels

Did you notice that cmdstanpy printed some divergent transition warnings above? This illustrates a pervasive problem with hierarchical models: funnel-shaped marginal posterior distributions. The plot below shows the values of the parameter \(\tau_{day}\) and the corresponding day effect values for Monday in the prior samples:

az.plot_pair(
    idata.prior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
);

As we discussed previously, funnels are hard to sample because of their inconsistent characteristic lengths. Unfortunately, they are often inevitable in hierarchical models. Can you see why from the graph?
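The reason is built into the model: conditional on \(\tau_{day}\), the day effects have standard deviation exactly \(\tau_{day}\), so the length scale the sampler has to handle shrinks and grows with \(\tau_{day}\) itself. A quick numerical illustration:

```python
import numpy as np

rng = np.random.default_rng(seed=12345)

# day_effect | tau_day ~ N(0, tau_day): the neck of the funnel (small tau_day)
# is orders of magnitude narrower than the bowl (large tau_day).
conditional_sd = {
    tau: rng.normal(loc=0, scale=tau, size=10_000).std()
    for tau in (0.01, 0.1, 1.0)
}
```

No single step size works well everywhere in such a geometry, which is why HMC reports divergences.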

There are three main solutions to funnels: add more information, tune the HMC algorithm, or reparameterise the model.

Add more information

Sampling from the posterior distribution didn’t produce any divergent transitions. This is probably because the extra information in the measurements made the geometry easier to sample. Comparing the two sets of marginal distributions illustrates how this can happen: note that the difference in scale between the neck and the bowl of the funnel is less extreme for the posterior samples.

from matplotlib import pyplot as plt
f, ax = plt.subplots()
az.plot_pair(
    idata.prior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
    ax=ax,
    scatter_kwargs={"label": "prior"},
);
az.plot_pair(
    idata.posterior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
    ax=ax,
    scatter_kwargs={"label": "posterior"},
);
ax.legend(frameon=False);

If better measurements aren’t available, divergences can often be avoided by searching for extra information that can justify narrower priors.

Tune the algorithm

Stan allows increasing the length of the warmup phase (iter_warmup, default 1000), bringing the target acceptance probability closer to 1 (adapt_delta, default 0.8) and increasing the leapfrog integrator’s maximum tree depth (max_treedepth, default 10). All of these changes trade speed for reliability.

mcmc_prior_2 = model.sample(
    data=data_prior,
    iter_warmup=3000,
    adapt_delta=0.99,
    max_treedepth=12
)
09:50:57 - cmdstanpy - INFO - CmdStan start processing
09:51:01 - cmdstanpy - INFO - CmdStan done processing.
09:51:01 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 23, column 2 to column 34)
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 22, column 2 to column 35)
Consider re-running with show_console=True if the above output is unclear!

Unfortunately even quite aggressive tuning doesn’t get rid of all the divergent transitions in this case.

Reparameterise

The idea behind reparameterisation is to define auxiliary parameters that don’t have problematic relationships, sample those, then recover the original parameters afterwards.

“Non-centred” parameterisations take a distribution with the form \(\alpha\sim D(\mu,\sigma)\) and express it as follows:

\[\begin{align*} u &\sim D(0, 1) \\ \alpha &= \mu + u \cdot \sigma \end{align*}\]
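It’s easy to check numerically that the two parameterisations describe the same distribution, here taking \(D\) to be the normal distribution with made-up values for \(\mu\) and \(\sigma\):

```python
import numpy as np

rng = np.random.default_rng(seed=12345)
MU, SIGMA = 1.5, 0.35  # made-up values for illustration

centred = rng.normal(loc=MU, scale=SIGMA, size=50_000)  # alpha ~ N(mu, sigma)
u = rng.normal(loc=0, scale=1, size=50_000)             # u ~ N(0, 1)
non_centred = MU + u * SIGMA                            # alpha = mu + u * sigma
```

The payoff is that the sampler works with \(u\), whose scale doesn’t depend on \(\sigma\), so the funnel between the scale parameter and the group effects disappears from the parameter space.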

model_nc = CmdStanModel(stan_file="../src/stan/plushies-nc.stan")
print(model_nc.code())
data {
 int<lower=1> N;
 int<lower=1> N_salesperson;
 int<lower=1> N_day;
 array[N] int<lower=1,upper=N_salesperson> salesperson;
 array[N] int<lower=1,upper=N_day> day;
 array[N] int<lower=0> sales;
 int<lower=0,upper=1> likelihood;
}
parameters {
 real log_mu;
 vector[N_salesperson] ability_z;
 vector[N_day] day_effect_z;
 real<lower=0> tau_ability;
 real<lower=0> tau_day;
}
transformed parameters {
 vector[N_salesperson] ability = ability_z * tau_ability;
 vector[N_day] day_effect = day_effect_z * tau_day;
 vector[N] log_lambda = log_mu + ability[salesperson] + day_effect[day]; 
}
model {
  log_mu ~ normal(0, 1);
  ability_z ~ normal(0, 1);
  day_effect_z ~ normal(0, 1);
  tau_ability ~ normal(0, 1);
  tau_day ~ normal(0, 1);
  if (likelihood){
    sales ~ poisson_log(log_lambda);
  }
}
generated quantities {
 real mu = exp(log_mu);
 vector[N] lambda = exp(log_lambda);
 array[N] int yrep = poisson_rng(lambda);
 vector[N] llik; 
 for (n in 1:N){
   llik[n] = poisson_lpmf(sales[n] | lambda[n]);
 }
}
mcmc_prior_nc = model_nc.sample(
    data=data_prior,
    iter_warmup=3000,
    adapt_delta=0.999,
    max_treedepth=12
)
09:51:02 - cmdstanpy - INFO - CmdStan start processing
09:51:17 - cmdstanpy - INFO - CmdStan done processing.
09:51:17 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 23, column 2 to column 34)
Consider re-running with show_console=True if the above output is unclear!
09:51:17 - cmdstanpy - WARNING - Some chains may have failed to converge.
    Chain 1 had 15 iterations at max treedepth (1.5%)
    Chain 3 had 3 divergent transitions (0.3%)
    Use the "diagnose()" method on the CmdStanMCMC object to see further information.

Beware of using non-centred parameterisation as a default: it isn’t guaranteed to be better.

So how many plushies do I need to sell?

f, ax = plt.subplots()
az.plot_forest(
    np.exp(idata.posterior["log_mu"] + idata.posterior["ability"]),
    kind="forestplot",
    combined=True,
    ax=ax,
    show=False,
);
ax.scatter(
    np.exp(np.log(BASELINE) + salespeople["ability"]), 
    ax.get_yticks()[::-1], 
    color="red", 
    label="True expected sales",
    zorder=2
)
ax.scatter(
    sales.groupby("salesperson")["sales"].mean().reindex(salespeople["salesperson"]), 
    ax.get_yticks()[::-1], 
    color="black", 
    label="Observed sales per day",
    zorder=3
)
ax.set(title="", xlabel="Number of plushies sold per day")
ax.axvline(BASELINE, linestyle="--", label="baseline", linewidth=0.8, color="black")
ax.legend(frameon=False);

Takeaways

  • Hierarchical models are a powerful way to capture structural information
  • You may run into problematic sampling, but you have options!
  • There is surprisingly little information in low-expected-value count data.